{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pacmap import PaCMAP\n",
    "from sklearn.manifold import TSNE\n",
    "from tqdm import tqdm\n",
    "from trimap import TRIMAP\n",
    "from umap import UMAP\n",
    "\n",
    "from processing.configs import COLORS_4\n",
    "from processing.load_datasets import load_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset identifiers; `datasets` is indexed by these names in later cells.\n",
    "datasets_names = [\n",
    "    \"rts\",\n",
    "    \"pdl\",\n",
    "    \"ioc\",\n",
    "    \"mjf\",\n",
    "]\n",
    "datasets = load_datasets(datasets_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computational efficiency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark: wall-clock duration of fit_transform for each method at increasing sample counts.\n",
    "N = [100, 1000, 5000, 10000, 50000, 100000]\n",
    "\n",
    "methods = [TSNE(), UMAP(), TRIMAP(), PaCMAP()]\n",
    "n_runs = 10\n",
    "\n",
    "X_rts = datasets[\"rts\"] # Run on RTS only, as it's the largest dataset and we want to see how the methods scale with the number of samples\n",
    "rng = np.random.default_rng(42)\n",
    "\n",
    "# times[n][method class name] -> list of n_runs durations in seconds\n",
    "times = {}\n",
    "for n in N:\n",
    "    # One subsample (drawn without replacement) per size, shared by every method at this n.\n",
    "    X = X_rts[rng.choice(X_rts.shape[0], n, replace=False)]\n",
    "    times_at_n = {}\n",
    "    print(f\"====== Running for n={n} ======\")\n",
    "    for method in methods:\n",
    "        method_name = method.__class__.__name__\n",
    "        print(f\"Running {method_name}...\")\n",
    "        times_for_method = []\n",
    "        for _ in range(n_runs):\n",
    "            # perf_counter() is monotonic and high resolution; time.time() follows the\n",
    "            # system clock, which may be adjusted while a run is in progress.\n",
    "            start = time.perf_counter()\n",
    "            method.fit_transform(X)  # the embedding is discarded; only the duration matters\n",
    "            times_for_method.append(time.perf_counter() - start)\n",
    "        times_at_n[method_name] = times_for_method\n",
    "    times[n] = times_at_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean fit_transform duration per method versus sample count, on log-log axes.\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for i, method in enumerate(methods):\n",
    "    method_name = method.__class__.__name__\n",
    "    means = np.array([np.mean(times[n][method_name]) for n in N])\n",
    "    # Spread over the repeated runs; without yerr, errorbar() only draws the connecting line.\n",
    "    stds = np.array([np.std(times[n][method_name]) for n in N])\n",
    "    ax.errorbar(N, means, yerr=stds, capsize=3, label=method_name, fmt=\"x-\", color=COLORS_4[i])\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel(\"Number of samples\", fontsize = 12, fontweight = \"bold\")\n",
    "ax.set_ylabel(\"Time (s)\", fontsize = 12, fontweight = \"bold\")\n",
    "ax.legend()\n",
    "ax.grid(True, which=\"major\", ls=\"--\", linewidth=0.5)\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "Path(\"images\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(\"images/computational_comparison.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.metrics import pairwise_distances\n",
    "from itertools import combinations\n",
    "\n",
    "def compute_pairwise_spearman(matrix1, matrix2):\n",
    "    \"\"\"\n",
    "    Spearman rank correlation between the strict upper triangles of two square matrices.\n",
    "\n",
    "    Only entries above the diagonal (k=1) are compared, so the diagonal and the\n",
    "    mirrored lower triangle do not enter the correlation.\n",
    "\n",
    "    Returns the correlation coefficient; the p-value from spearmanr is dropped.\n",
    "    \"\"\"\n",
    "    upper_first = np.triu_indices_from(matrix1, k=1)\n",
    "    upper_second = np.triu_indices_from(matrix2, k=1)\n",
    "    rho = spearmanr(matrix1[upper_first], matrix2[upper_second])[0]\n",
    "    return rho\n",
    "\n",
    "def stability_spearman(dr_model, data, sample_size=500, n_runs=10, random_state=None, method_name=\"\"):\n",
    "    \"\"\"\n",
    "    Measures how consistent a dimensionality-reduction method is across random seeds.\n",
    "\n",
    "    One random subsample of `data` is embedded `n_runs` times, each time with a newly\n",
    "    drawn seed. Every pair of embeddings is then compared through the Spearman\n",
    "    correlation of their pairwise-distance matrices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dr_model : callable\n",
    "        Factory returning an object with `fit_transform`. Called as\n",
    "        `dr_model(random_state=<int>)`, or as `dr_model()` when `method_name` is TriMap.\n",
    "    data : indexable\n",
    "        Must support `len(data)` and indexing with an integer index array.\n",
    "    sample_size : int\n",
    "        Number of rows drawn without replacement; capped at `len(data)`.\n",
    "    n_runs : int\n",
    "        Number of embeddings to compute. With fewer than two there are no pairs and\n",
    "        the returned statistics are NaN.\n",
    "    random_state : optional\n",
    "        Passed to `np.random.default_rng`; drives both the subsample and the per-run seeds.\n",
    "    method_name : str\n",
    "        When equal to \"TriMap\" (any casing), NumPy's global RNG is seeded before each\n",
    "        run instead of passing `random_state` to the factory.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (mean, standard deviation) of the pairwise Spearman correlations.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(random_state)\n",
    "    # Sampling without replacement raises when more rows are requested than exist.\n",
    "    sample_size = min(sample_size, len(data))\n",
    "    idx = rng.choice(len(data), size=sample_size, replace=False)\n",
    "    sample = data[idx]\n",
    "\n",
    "    seed_via_numpy = str(method_name).lower() == \"trimap\"\n",
    "    embeddings = []\n",
    "    for _ in tqdm(range(n_runs)):\n",
    "        run_seed = int(rng.integers(1_000_000))  # plain int, as estimators expect\n",
    "        if seed_via_numpy:\n",
    "            np.random.seed(run_seed)  # Seed NumPy RNG\n",
    "            model = dr_model()\n",
    "        else:\n",
    "            model = dr_model(random_state=run_seed)\n",
    "        embeddings.append(model.fit_transform(sample))\n",
    "\n",
    "    distance_matrices = [pairwise_distances(embed) for embed in embeddings]\n",
    "    # combinations(..., 2) yields each unordered pair (i < j) exactly once.\n",
    "    correlations = [\n",
    "        compute_pairwise_spearman(dm1, dm2)\n",
    "        for dm1, dm2 in combinations(distance_matrices, 2)\n",
    "    ]\n",
    "    return np.mean(correlations), np.std(correlations)\n",
    "\n",
    "# Define DR model wrappers\n",
    "def tsne_model(random_state=None):\n",
    "    \"\"\"Build a 2-component t-SNE estimator (perplexity 30, random init) with the given seed.\"\"\"\n",
    "    tsne_kwargs = dict(n_components=2, perplexity=30, init='random', random_state=random_state)\n",
    "    return TSNE(**tsne_kwargs)\n",
    "\n",
    "def umap_model(random_state=None):\n",
    "    \"\"\"Build a 2-component UMAP estimator with the given seed.\"\"\"\n",
    "    umap_kwargs = dict(n_components=2, random_state=random_state)\n",
    "    return UMAP(**umap_kwargs)\n",
    "\n",
    "def trimap_model(random_state=None):\n",
    "    \"\"\"\n",
    "    Build a 2-dimensional TriMap estimator.\n",
    "\n",
    "    TRIMAP is constructed without a seed argument here, so `random_state` is applied\n",
    "    by seeding NumPy's global RNG, the same mechanism stability_spearman uses for\n",
    "    TriMap. With `random_state=None` the global RNG is left untouched.\n",
    "    \"\"\"\n",
    "    if random_state is not None:\n",
    "        # NOTE(review): assumes TRIMAP draws from NumPy's global RNG - confirm.\n",
    "        np.random.seed(random_state)\n",
    "    return TRIMAP(n_dims=2)\n",
    "\n",
    "def pacmap_model(random_state=None):\n",
    "    \"\"\"Build a 2-component PaCMAP estimator with the given seed.\"\"\"\n",
    "    pacmap_kwargs = dict(n_components=2, random_state=random_state)\n",
    "    return PaCMAP(**pacmap_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute stability for each method\n",
    "N = 10000\n",
    "n_runs = 10\n",
    "\n",
    "rows = []\n",
    "for dataset_name in datasets_names:\n",
    "    X = datasets[dataset_name]\n",
    "    print(f\"Computing stability for {dataset_name} with sample size {N} and {n_runs} runs...\")\n",
    "\n",
    "    results = {\n",
    "        \"TSNE\": stability_spearman(tsne_model, X, sample_size=N, n_runs=n_runs),\n",
    "        \"UMAP\": stability_spearman(umap_model, X, sample_size=N, n_runs=n_runs),\n",
    "        \"TRIMAP\": stability_spearman(trimap_model, X, sample_size=N, n_runs=n_runs),\n",
    "        \"PACMAP\": stability_spearman(pacmap_model, X, sample_size=N, n_runs=n_runs),\n",
    "    }\n",
    "\n",
    "    for method, (mean_stability, std_dev) in results.items():\n",
    "        rows.append({\n",
    "            \"Method\": method,\n",
    "            \"Mean Stability\": mean_stability,\n",
    "            \"Std Dev\": std_dev,\n",
    "            \"Dataset\": dataset_name\n",
    "        })\n",
    "\n",
    "results_df = pd.DataFrame(rows, columns=[\"Method\", \"Mean Stability\", \"Std Dev\", \"Dataset\"])\n",
    "results_df.to_csv(\"results/stability_results.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grouped bar chart: one group per dataset, one bar per method, error bars = std of correlations.\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Not named `methods`: that name holds the estimator instances used by the timing cells.\n",
    "method_labels = [\"TSNE\", \"UMAP\", \"TRIMAP\", \"PACMAP\"]\n",
    "x = np.arange(len(datasets_names))\n",
    "width = 0.18\n",
    "\n",
    "for i, label in enumerate(method_labels):\n",
    "    # Align rows to datasets_names explicitly rather than relying on the row order of results_df.\n",
    "    per_dataset = results_df[results_df[\"Method\"] == label].set_index(\"Dataset\").reindex(datasets_names)\n",
    "    means = per_dataset[\"Mean Stability\"].values\n",
    "    stds = per_dataset[\"Std Dev\"].values\n",
    "    # Offset centres the four bars of a group on its tick.\n",
    "    ax.bar(x + i*width - 1.5*width, means, width, yerr=stds, label=label, capsize=4, color=COLORS_4[i], edgecolor='black')\n",
    "\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels([d.upper() for d in datasets_names], fontsize=12, fontweight=\"bold\")\n",
    "ax.set_ylabel(\"Mean Stability\", fontsize=12, fontweight=\"bold\")\n",
    "ax.legend(loc=\"upper center\", bbox_to_anchor=(0.5, 1.1), ncols=4, fontsize = 14)\n",
    "\n",
    "plt.tight_layout()\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "Path(\"images\").mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(\"images/stability_comparison.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dr_eval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
